{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "f_KNv25DX7B6"
   },
   "source": [
    "# [Super SloMo](https://people.cs.umass.edu/~hzjiang/projects/superslomo/)\n",
    "## High Quality Estimation of Multiple Intermediate Frames for Video Interpolation\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "0VWuBGh6zMMZ"
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import torch.optim as optim\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import slomo_model as model\n",
    "import dataloader\n",
    "import matplotlib.pyplot as plt\n",
    "from math import log10\n",
    "from IPython.display import clear_output, display\n",
    "import datetime\n",
    "from tensorboardX import SummaryWriter"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "1VynXmoKp_3M"
   },
   "source": [
    "## Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "N2yrOVZjqDe9"
   },
   "outputs": [],
   "source": [
    "# Learning rate. Set `MILESTONES` to the epoch values at which the learning\n",
    "# rate should be decreased by a factor of 0.1.\n",
    "INITIAL_LEARNING_RATE = 0.0001\n",
    "MILESTONES = [100, 150]\n",
    "\n",
    "# Number of epochs to train\n",
    "EPOCHS = 200\n",
    "\n",
    "# Choose batch sizes according to the available GPU/CPU memory.\n",
    "# This configuration works on a GTX 1080 Ti.\n",
    "TRAIN_BATCH_SIZE = 6\n",
    "VALIDATION_BATCH_SIZE = 10\n",
    "\n",
    "# Path to the dataset folder containing the train/test/validation folders\n",
    "# (this notebook reads `train` and `validation`).\n",
    "DATASET_ROOT = \"path/to/dataset\"\n",
    "\n",
    "# Path to the folder where checkpoints are saved\n",
    "CHECKPOINT_DIR = 'path/to/checkpoint_directory'\n",
    "\n",
    "# To resume from a checkpoint, set `TRAINING_CONTINUE` to True and point\n",
    "# `CHECKPOINT_PATH` at the checkpoint file.\n",
    "TRAINING_CONTINUE = False\n",
    "CHECKPOINT_PATH = 'path/to/checkpoint/file'\n",
    "\n",
    "# Progress-report and validation frequency (N: after every N training iterations)\n",
    "PROGRESS_ITER = 100\n",
    "\n",
    "# Checkpoint frequency (N: after every N epochs). Each checkpoint is roughly 151 MB.\n",
    "CHECKPOINT_EPOCH = 5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Yr3Lm1ovbWv1"
   },
   "source": [
    "## [TensorboardX](https://github.com/lanpa/tensorboardX)\n",
    "### For visualizing loss and interpolated frames"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "saUJTMiMCAzH"
   },
   "outputs": [],
   "source": [
    "# TensorBoard event files are written to the 'log' directory (relative to the working directory).\n",
    "LOG_DIR = 'log'\n",
    "writer = SummaryWriter(LOG_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Ua1DJm82aj5-"
   },
   "source": [
    "### Initialize the flow-computation and arbitrary-time flow-interpolation CNNs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "D42vzEKrWtpG"
   },
   "outputs": [],
   "source": [
    "# Use the first GPU when available, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Flow computation: input is I0 and I1 concatenated on the channel axis, output\n",
    "# channels 0-1 / 2-3 are read as F_0_1 / F_1_0 further below.\n",
    "flowComp = model.UNet(6, 4).to(device)\n",
    "\n",
    "# Arbitrary-time flow interpolation: output channels 0-3 are flow residuals and\n",
    "# channel 4 the visibility-map logits (see the training loop).\n",
    "ArbTimeFlowIntrp = model.UNet(20, 5).to(device)\n",
    "ArbTimeFlowIntrp"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "UYMpk2EYchaY"
   },
   "source": [
    "### Initialize backward warpers for the training and validation datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "vJq6SrWIf2GE"
   },
   "outputs": [],
   "source": [
    "# Backward-warping operators, one per input resolution. 640x352 matches the\n",
    "# validation set's `randomCropSize`; 352x352 presumably matches the training\n",
    "# dataset's default crop size (TODO confirm in dataloader).\n",
    "trainFlowBackWarp = model.backWarp(352, 352, device).to(device)\n",
    "validationFlowBackWarp = model.backWarp(640, 352, device).to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "oSs9UaIjdTT2"
   },
   "source": [
    "### Load datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "MJ9cVigEgtyT"
   },
   "outputs": [],
   "source": [
    "# Channel-wise mean of the adobe240-fps training set. The std is 1 for every\n",
    "# channel, so normalization only re-centres the pixel values.\n",
    "mean = [0.429, 0.431, 0.397]\n",
    "std = [1] * 3\n",
    "normalize = transforms.Normalize(mean=mean, std=std)\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    normalize,\n",
    "])\n",
    "\n",
    "# Training samples: random crops, shuffled every epoch.\n",
    "trainset = dataloader.SuperSloMo(root=DATASET_ROOT + '/train',\n",
    "                                 transform=transform,\n",
    "                                 train=True)\n",
    "trainloader = torch.utils.data.DataLoader(trainset,\n",
    "                                          batch_size=TRAIN_BATCH_SIZE,\n",
    "                                          shuffle=True)\n",
    "\n",
    "# Validation samples: 640x352 crops in a fixed order.\n",
    "validationset = dataloader.SuperSloMo(root=DATASET_ROOT + '/validation',\n",
    "                                      transform=transform,\n",
    "                                      randomCropSize=(640, 352),\n",
    "                                      train=False)\n",
    "validationloader = torch.utils.data.DataLoader(validationset,\n",
    "                                               batch_size=VALIDATION_BATCH_SIZE,\n",
    "                                               shuffle=False)\n",
    "\n",
    "print(trainset, validationset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "WXmNMdbJfp2d"
   },
   "source": [
    "### Create a transform to display an image from a tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "try3adPHgwse"
   },
   "outputs": [],
   "source": [
    "# Exact inverse of `normalize`. Normalize computes (x - m) / s, so undoing it\n",
    "# needs mean' = -m / s and std' = 1 / s; with std == 1 this reduces to adding\n",
    "# the mean back, but it stays correct if `std` is changed.\n",
    "negmean = [-m / s for m, s in zip(mean, std)]\n",
    "revNormalize = transforms.Normalize(mean=negmean, std=[1 / s for s in std])\n",
    "# Normalized tensor -> de-normalized PIL image, for display with matplotlib.\n",
    "TP = transforms.Compose([revNormalize, transforms.ToPILImage()])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "32XZg9Mfd5bN"
   },
   "source": [
    "### Test the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "0Vyf7dbwCO1E"
   },
   "outputs": [],
   "source": [
    "# Sanity check: show the three frames of the first sample of one training batch\n",
    "# (reference frame, ground-truth intermediate frame, reference frame).\n",
    "for trainData, frameIndex in trainloader:\n",
    "    frame0, frameT, frame1 = trainData\n",
    "    print(\"Intermediate frame index: \", (frameIndex[0]))\n",
    "    for position, frame in enumerate((frame0, frameT, frame1)):\n",
    "        if position:\n",
    "            plt.figure()  # every frame after the first gets its own figure\n",
    "        plt.imshow(TP(frame[0]))\n",
    "        plt.grid(True)\n",
    "    break  # one batch is enough"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "rh0MK2qKuBlV"
   },
   "source": [
    "### Utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "BdMFU0ijfIuI"
   },
   "outputs": [],
   "source": [
    "plt.rcParams['figure.figsize'] = [15, 3]\n",
    "\n",
    "def Plot(num, listInp, d):\n",
    "    \"\"\"Plot the per-epoch mean of a nested metric history into subplot `num` of a 1x2 grid.\n",
    "\n",
    "    Args:\n",
    "        num: subplot index (1 or 2) in a 1-row, 2-column layout.\n",
    "        listInp: list holding one inner list of recorded values per epoch.\n",
    "        d: matplotlib colour of the line.\n",
    "\n",
    "    Epochs that recorded no values (e.g. fewer than PROGRESS_ITER iterations in\n",
    "    an epoch) are plotted as NaN, i.e. a gap, instead of raising ZeroDivisionError.\n",
    "    \"\"\"\n",
    "    epochMeans = [sum(values) / len(values) if values else float('nan') for values in listInp]\n",
    "    plt.subplot(1, 2, num)\n",
    "    plt.plot(epochMeans, color=d)\n",
    "    plt.grid(True)\n",
    "\n",
    "\n",
    "def get_lr(optimizer):\n",
    "    \"\"\"Return the learning rate of the optimizer's first parameter group (None if it has none).\"\"\"\n",
    "    for param_group in optimizer.param_groups:\n",
    "        return param_group['lr']\n",
    "    return None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "mooLcmxtpPR_"
   },
   "source": [
    "### Loss and optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "BuWQfcb-jhWx"
   },
   "outputs": [],
   "source": [
    "# L1 drives the reconstruction and warping losses; MSE drives the perceptual loss and PSNR.\n",
    "L1_lossFn = nn.L1Loss()\n",
    "MSE_LossFn = nn.MSELoss()\n",
    "\n",
    "# Both networks are optimized jointly by a single Adam optimizer.\n",
    "params = [*ArbTimeFlowIntrp.parameters(), *flowComp.parameters()]\n",
    "optimizer = optim.Adam(params, lr=INITIAL_LEARNING_RATE)\n",
    "\n",
    "# Multiply the learning rate by 0.1 whenever the scheduler's step count reaches a value in MILESTONES.\n",
    "scheduler = optim.lr_scheduler.MultiStepLR(optimizer, gamma=0.1, milestones=MILESTONES)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "a5rIkwwfpk1n"
   },
   "source": [
    "### Initialize the VGG16 model for the perceptual loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "9WR_NxHP51oB"
   },
   "outputs": [],
   "source": [
    "# Perceptual-loss feature extractor: the pretrained VGG16 `features` stack cut\n",
    "# after layer index 21 (conv4_3 in torchvision's VGG16 layout), all weights frozen.\n",
    "vgg16 = torchvision.models.vgg16(pretrained=True)\n",
    "vgg16_conv_4_3 = nn.Sequential(*vgg16.features[:22])\n",
    "vgg16_conv_4_3.to(device)\n",
    "for weight in vgg16_conv_4_3.parameters():\n",
    "    weight.requires_grad = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "9-6wLaBJZqsm"
   },
   "source": [
    "### Validation function\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "RhMMZ_I4iDFf"
   },
   "outputs": [],
   "source": [
    "def validate():\n",
    "    \"\"\"Evaluate flowComp + ArbTimeFlowIntrp on the whole validation set.\n",
    "\n",
    "    Runs the same forward pass and loss as the training loop, without gradients,\n",
    "    on every batch of `validationloader`.\n",
    "\n",
    "    Returns:\n",
    "        (mean PSNR, mean total loss, image grid): both means are taken over the\n",
    "        validation batches; the grid holds the de-normalized I0, ground-truth It,\n",
    "        predicted It and I1 of the first sample of the first batch.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `validationloader` yields no batches (the averages and the\n",
    "        image grid would otherwise be undefined).\n",
    "    \"\"\"\n",
    "    if len(validationloader) == 0:\n",
    "        raise ValueError(\"validate(): validationloader is empty\")\n",
    "    psnr = 0\n",
    "    tloss = 0\n",
    "    retImg = None\n",
    "    with torch.no_grad():\n",
    "        for validationIndex, (validationData, validationFrameIndex) in enumerate(validationloader, 0):\n",
    "            frame0, frameT, frame1 = validationData\n",
    "\n",
    "            I0 = frame0.to(device)\n",
    "            I1 = frame1.to(device)\n",
    "            IFrame = frameT.to(device)\n",
    "\n",
    "            # Flow between the reference frames: channels 0-1 are F_0_1, 2-3 are F_1_0.\n",
    "            flowOut = flowComp(torch.cat((I0, I1), dim=1))\n",
    "            F_0_1 = flowOut[:,:2,:,:]\n",
    "            F_1_0 = flowOut[:,2:,:,:]\n",
    "\n",
    "            # Combine F_0_1 / F_1_0 into intermediate flows using coefficients that\n",
    "            # model.getFlowCoeff derives from the intermediate-frame indices.\n",
    "            fCoeff = model.getFlowCoeff(validationFrameIndex, device)\n",
    "\n",
    "            F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0\n",
    "            F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0\n",
    "\n",
    "            g_I0_F_t_0 = validationFlowBackWarp(I0, F_t_0)\n",
    "            g_I1_F_t_1 = validationFlowBackWarp(I1, F_t_1)\n",
    "\n",
    "            # Flow residuals (channels 0-3) and visibility-map logits (channel 4).\n",
    "            intrpOut = ArbTimeFlowIntrp(torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1, g_I0_F_t_0), dim=1))\n",
    "\n",
    "            F_t_0_f = intrpOut[:, :2, :, :] + F_t_0\n",
    "            F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1\n",
    "            # torch.sigmoid: same result as the deprecated F.sigmoid.\n",
    "            V_t_0   = torch.sigmoid(intrpOut[:, 4:5, :, :])\n",
    "            V_t_1   = 1 - V_t_0\n",
    "\n",
    "            g_I0_F_t_0_f = validationFlowBackWarp(I0, F_t_0_f)\n",
    "            g_I1_F_t_1_f = validationFlowBackWarp(I1, F_t_1_f)\n",
    "\n",
    "            wCoeff = model.getWarpCoeff(validationFrameIndex, device)\n",
    "\n",
    "            # Visibility-weighted blend of the two refined warps.\n",
    "            Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 * g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)\n",
    "\n",
    "            # For tensorboard (first batch only). Clone before de-normalizing: older\n",
    "            # torchvision versions apply Normalize in place, and on CPU `.to(device)` /\n",
    "            # `.cpu()` return the very tensors that the losses below still read.\n",
    "            if retImg is None:\n",
    "                retImg = torchvision.utils.make_grid([revNormalize(frame0[0].clone()), revNormalize(frameT[0].clone()), revNormalize(Ft_p.cpu()[0].clone()), revNormalize(frame1[0].clone())], padding=10)\n",
    "\n",
    "            # Loss: same terms and weights as in training.\n",
    "            recnLoss = L1_lossFn(Ft_p, IFrame)\n",
    "\n",
    "            prcpLoss = MSE_LossFn(vgg16_conv_4_3(Ft_p), vgg16_conv_4_3(IFrame))\n",
    "\n",
    "            warpLoss = L1_lossFn(g_I0_F_t_0, IFrame) + L1_lossFn(g_I1_F_t_1, IFrame) + L1_lossFn(validationFlowBackWarp(I0, F_1_0), I1) + L1_lossFn(validationFlowBackWarp(I1, F_0_1), I0)\n",
    "\n",
    "            loss_smooth_1_0 = torch.mean(torch.abs(F_1_0[:, :, :, :-1] - F_1_0[:, :, :, 1:])) + torch.mean(torch.abs(F_1_0[:, :, :-1, :] - F_1_0[:, :, 1:, :]))\n",
    "            loss_smooth_0_1 = torch.mean(torch.abs(F_0_1[:, :, :, :-1] - F_0_1[:, :, :, 1:])) + torch.mean(torch.abs(F_0_1[:, :, :-1, :] - F_0_1[:, :, 1:, :]))\n",
    "            loss_smooth = loss_smooth_1_0 + loss_smooth_0_1\n",
    "\n",
    "            loss = 204 * recnLoss + 102 * warpLoss + 0.005 * prcpLoss + loss_smooth\n",
    "            tloss += loss.item()\n",
    "\n",
    "            # PSNR with peak value 1. The MSE is floored at 1e-10 (PSNR capped at 100 dB)\n",
    "            # so a perfect prediction does not raise ZeroDivisionError.\n",
    "            MSE_val = MSE_LossFn(Ft_p, IFrame)\n",
    "            psnr += (10 * log10(1 / max(MSE_val.item(), 1e-10)))\n",
    "\n",
    "    return (psnr / len(validationloader)), (tloss / len(validationloader)), retImg"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Eh1LB1ufZziF"
   },
   "source": [
    "### Test validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "axBjslWlot7I"
   },
   "outputs": [],
   "source": [
    "# Quick check of validate() before training: mean PSNR, mean loss and grid size.\n",
    "checkPSNR, checkLoss, checkGrid = validate()\n",
    "print(checkPSNR, checkLoss, checkGrid.size())\n",
    "# The grid tensor is channel-first; imshow wants channel-last.\n",
    "plt.imshow(checkGrid.permute(1, 2, 0).numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "1PIFbXuKpBBe"
   },
   "source": [
    "### Initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "gWt-nlx2MSOk"
   },
   "outputs": [],
   "source": [
    "from bisect import bisect_right\n",
    "\n",
    "if TRAINING_CONTINUE:\n",
    "    # Restore the network weights saved by the training loop.\n",
    "    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.\n",
    "    dict1 = torch.load(CHECKPOINT_PATH)\n",
    "    ArbTimeFlowIntrp.load_state_dict(dict1['state_dictAT'])\n",
    "    flowComp.load_state_dict(dict1['state_dictFC'])\n",
    "    # The LR scheduler is not stored in the checkpoint, so without this a resumed\n",
    "    # run would go back to INITIAL_LEARNING_RATE even past the MILESTONES.\n",
    "    # The scheduler is *set* (not stepped) so re-running this cell is harmless:\n",
    "    # step count = first epoch still to train, and each base LR is decayed once\n",
    "    # per milestone already reached.\n",
    "    resumeStep = dict1['epoch'] + 1\n",
    "    decays = bisect_right(sorted(scheduler.milestones), resumeStep)\n",
    "    scheduler.last_epoch = resumeStep\n",
    "    for paramGroup, baseLr in zip(optimizer.param_groups, scheduler.base_lrs):\n",
    "        paramGroup['lr'] = baseLr * scheduler.gamma ** decays\n",
    "else:\n",
    "    # Fresh run: empty metric histories; 'epoch': -1 makes training start at epoch 0.\n",
    "    dict1 = {'loss': [], 'valLoss': [], 'valPSNR': [], 'epoch': -1}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "RbQnS_KNilbR"
   },
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "both",
    "colab": {},
    "colab_type": "code",
    "id": "QrAS6TmP11RW"
   },
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "start = time.time()\n",
    "cLoss   = dict1['loss']\n",
    "valLoss = dict1['valLoss']\n",
    "valPSNR = dict1['valPSNR']\n",
    "# Continue the checkpoint numbering of a resumed run instead of restarting at 0,\n",
    "# which would overwrite checkpoints already in CHECKPOINT_DIR. Checkpoint k is\n",
    "# written at the end of epoch k * CHECKPOINT_EPOCH + CHECKPOINT_EPOCH - 1, so a\n",
    "# fresh run (dict1['epoch'] == -1) still starts at 0.\n",
    "checkpoint_counter = (dict1['epoch'] + 1) // CHECKPOINT_EPOCH\n",
    "\n",
    "### Main training loop\n",
    "for epoch in range(dict1['epoch'] + 1, EPOCHS):\n",
    "    clear_output()\n",
    "    print(\"Epoch: \", epoch)\n",
    "\n",
    "    # Plots: train (red) and validation (blue) loss on the left, validation PSNR on the right\n",
    "    if (epoch):\n",
    "        Plot(1, cLoss, 'red')\n",
    "        Plot(1, valLoss, 'blue')\n",
    "        Plot(2, valPSNR, 'green')\n",
    "        display(plt.gcf())\n",
    "\n",
    "    # Open a new per-epoch bucket for each metric and reset the running train loss\n",
    "    cLoss.append([])\n",
    "    valLoss.append([])\n",
    "    valPSNR.append([])\n",
    "    iLoss = 0\n",
    "\n",
    "    for trainIndex, (trainData, trainFrameIndex) in enumerate(trainloader, 0):\n",
    "\n",
    "        ## Getting the input and the target from the training set\n",
    "        frame0, frameT, frame1 = trainData\n",
    "\n",
    "        I0 = frame0.to(device)\n",
    "        I1 = frame1.to(device)\n",
    "        IFrame = frameT.to(device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # Calculate flow between reference frames I0 and I1\n",
    "        flowOut = flowComp(torch.cat((I0, I1), dim=1))\n",
    "\n",
    "        # Extracting flows between I0 and I1 - F_0_1 and F_1_0\n",
    "        F_0_1 = flowOut[:,:2,:,:]\n",
    "        F_1_0 = flowOut[:,2:,:,:]\n",
    "\n",
    "        fCoeff = model.getFlowCoeff(trainFrameIndex, device)\n",
    "\n",
    "        # Calculate intermediate flows\n",
    "        F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0\n",
    "        F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0\n",
    "\n",
    "        # Get intermediate frames from the intermediate flows\n",
    "        g_I0_F_t_0 = trainFlowBackWarp(I0, F_t_0)\n",
    "        g_I1_F_t_1 = trainFlowBackWarp(I1, F_t_1)\n",
    "\n",
    "        # Calculate optical flow residuals and visibility maps\n",
    "        intrpOut = ArbTimeFlowIntrp(torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1, g_I0_F_t_0), dim=1))\n",
    "\n",
    "        # Extract optical flow residuals and visibility maps\n",
    "        # (torch.sigmoid: same result as the deprecated F.sigmoid)\n",
    "        F_t_0_f = intrpOut[:, :2, :, :] + F_t_0\n",
    "        F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1\n",
    "        V_t_0   = torch.sigmoid(intrpOut[:, 4:5, :, :])\n",
    "        V_t_1   = 1 - V_t_0\n",
    "\n",
    "        # Get intermediate frames from the refined intermediate flows\n",
    "        g_I0_F_t_0_f = trainFlowBackWarp(I0, F_t_0_f)\n",
    "        g_I1_F_t_1_f = trainFlowBackWarp(I1, F_t_1_f)\n",
    "\n",
    "        wCoeff = model.getWarpCoeff(trainFrameIndex, device)\n",
    "\n",
    "        # Calculate final intermediate frame\n",
    "        Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 * g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)\n",
    "\n",
    "        # Loss\n",
    "        recnLoss = L1_lossFn(Ft_p, IFrame)\n",
    "\n",
    "        prcpLoss = MSE_LossFn(vgg16_conv_4_3(Ft_p), vgg16_conv_4_3(IFrame))\n",
    "\n",
    "        warpLoss = L1_lossFn(g_I0_F_t_0, IFrame) + L1_lossFn(g_I1_F_t_1, IFrame) + L1_lossFn(trainFlowBackWarp(I0, F_1_0), I1) + L1_lossFn(trainFlowBackWarp(I1, F_0_1), I0)\n",
    "\n",
    "        loss_smooth_1_0 = torch.mean(torch.abs(F_1_0[:, :, :, :-1] - F_1_0[:, :, :, 1:])) + torch.mean(torch.abs(F_1_0[:, :, :-1, :] - F_1_0[:, :, 1:, :]))\n",
    "        loss_smooth_0_1 = torch.mean(torch.abs(F_0_1[:, :, :, :-1] - F_0_1[:, :, :, 1:])) + torch.mean(torch.abs(F_0_1[:, :, :-1, :] - F_0_1[:, :, 1:, :]))\n",
    "        loss_smooth = loss_smooth_1_0 + loss_smooth_0_1\n",
    "\n",
    "        # Total Loss - Coefficients 204 and 102 are used instead of 0.8 and 0.4\n",
    "        # since the loss in paper is calculated for input pixels in range 0-255\n",
    "        # and the input to our network is in range 0-1\n",
    "        loss = 204 * recnLoss + 102 * warpLoss + 0.005 * prcpLoss + loss_smooth\n",
    "\n",
    "        # Backpropagate\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        iLoss += loss.item()\n",
    "\n",
    "        # Validation and progress every `PROGRESS_ITER` iterations\n",
    "        if ((trainIndex % PROGRESS_ITER) == PROGRESS_ITER - 1):\n",
    "            end = time.time()\n",
    "\n",
    "            psnr, vLoss, valImg = validate()\n",
    "\n",
    "            # [-1] is this epoch's bucket (appended above); it equals [epoch] whenever\n",
    "            # the histories are in sync with the epoch counter.\n",
    "            valPSNR[-1].append(psnr)\n",
    "            valLoss[-1].append(vLoss)\n",
    "\n",
    "            #Tensorboard\n",
    "            itr = trainIndex + epoch * (len(trainloader))\n",
    "\n",
    "            writer.add_scalars('Loss', {'trainLoss': iLoss/PROGRESS_ITER,\n",
    "                                        'validationLoss': vLoss}, itr)\n",
    "            writer.add_scalar('PSNR', psnr, itr)\n",
    "\n",
    "            writer.add_image('Validation',valImg , itr)\n",
    "            #####\n",
    "\n",
    "            endVal = time.time()\n",
    "\n",
    "            print(\" Loss: %0.6f  Iterations: %4d/%4d  TrainExecTime: %0.1f  ValLoss:%0.6f  ValPSNR: %0.4f  ValEvalTime: %0.2f LearningRate: %f\" % (iLoss / PROGRESS_ITER, trainIndex, len(trainloader), end - start, vLoss, psnr, endVal - end, get_lr(optimizer)))\n",
    "\n",
    "            cLoss[-1].append(iLoss/PROGRESS_ITER)\n",
    "            iLoss = 0\n",
    "            start = time.time()\n",
    "\n",
    "    # Create checkpoint after every `CHECKPOINT_EPOCH` epochs\n",
    "    if ((epoch % CHECKPOINT_EPOCH) == CHECKPOINT_EPOCH - 1):\n",
    "        dict1 = {\n",
    "                'Detail':\"End to end Super SloMo.\",\n",
    "                'epoch':epoch,\n",
    "                'timestamp':datetime.datetime.now(),\n",
    "                'trainBatchSz':TRAIN_BATCH_SIZE,\n",
    "                'validationBatchSz':VALIDATION_BATCH_SIZE,\n",
    "                'learningRate':get_lr(optimizer),\n",
    "                'loss':cLoss,\n",
    "                'valLoss':valLoss,\n",
    "                'valPSNR':valPSNR,\n",
    "                'state_dictFC': flowComp.state_dict(),\n",
    "                'state_dictAT': ArbTimeFlowIntrp.state_dict(),\n",
    "                }\n",
    "        torch.save(dict1, CHECKPOINT_DIR + \"/SuperSloMo\" + str(checkpoint_counter) + \".ckpt\")\n",
    "        checkpoint_counter += 1\n",
    "\n",
    "    # Advance the LR schedule once per epoch, after this epoch's optimizer steps\n",
    "    # (the order PyTorch >= 1.1 requires). Stepping at the start of the epoch\n",
    "    # applied every MILESTONES decay one epoch early. The checkpoint above\n",
    "    # therefore records the learning rate actually used during this epoch.\n",
    "    scheduler.step()\n",
    "    plt.close('all')"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "train.ipynb",
   "provenance": [],
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}